"""
 @Author: zzgsg
 @time: 2024/9/28 20:32
"""

# Figure 3. Proportions of LET nodes where lookup error exceeds 0.1.
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np

def _read_lut_file(path):
    """Load every dataset of an HDF5 LUT file into memory.

    Parameters
    ----------
    path : str
        Path to the HDF5 file.

    Returns
    -------
    (dict, numpy.ndarray)
        A mapping of every dataset name except 'LUT' to its array, and the
        'LUT' dataset itself.

    Raises
    ------
    KeyError
        If the file contains no 'LUT' dataset.
    """
    # ``[:]`` copies each dataset into a NumPy array, so the file can be
    # closed as soon as the ``with`` block ends.
    with h5py.File(path, 'r') as h5:
        arrays = {name: h5[name][:] for name in h5.keys()}
    lut = arrays.pop('LUT')
    return arrays, lut


# Only the parameter arrays of this table are used by the live code below
# (node-spacing weights); LUT2 itself appears only in commented-out code.
LUT_parameters2, LUT2 = _read_lut_file(r'O4LUT_360nm_damf.h5')
# LUT holds the values that Figure 3 compares against the 0.1 error threshold;
# LUT_parameters supplies the x tick labels of each panel.
LUT_parameters, LUT = _read_lut_file(r'O4LUT_360nm_damf_err.h5')

# Build one weight vector per LUT parameter axis and multiply them together
# (broadcast over six axes) into `proportion`:
#   * 'elev' gets uniform weights, one per node;
#   * every other axis gets the spacing between neighbouring nodes (np.diff).
# Each vector is normalised to sum to 1 and kept in `diffk[name]`.
# NOTE(review): assumes each parameter array is 1-D. np.diff yields one entry
# fewer than there are nodes, so this broadcasts against LUT only if the error
# table has n-1 entries along those axes (presumably cell midpoints) - TODO confirm.
# NOTE(review): `proportion` is not read anywhere else in this script.
diffk = {}
proportion = np.ones_like(LUT)
for axis, name in enumerate(['sza', 'raa', 'elev', 'c', 'h', 's']):
    nodes = LUT_parameters2[name]
    weights = np.ones(nodes.shape[0]) if name == 'elev' else np.diff(nodes)
    weights = weights / np.sum(weights)
    diffk[name] = weights
    # Insert singleton dimensions on every other axis so the vector lies
    # along `axis` of a 6-D array.
    other_axes = tuple(ax for ax in range(6) if ax != axis)
    proportion = proportion * np.expand_dims(weights, axis=other_axes)

# Figure 3: one bar panel per LUT parameter, laid out in a 3x2 grid. Each panel
# shows, for every node of that parameter, the fraction of error-table entries
# above ERR_THRESHOLD. A final 'all' bar gives the fraction over the whole table.
ERR_THRESHOLD = 0.1  # lookup-error threshold quoted in the figure caption
param_keys = ['sza', 'raa', 'elev', 'c', 'h', 's']
colors = ['#F28C8C', '#6EA7E8', '#8ECB8C', '#FFBF70', '#AE89DB', '#FFA085']
order = ['a', 'b', 'c', 'd', 'e', 'f']  # panel labels

plt.rcParams['font.size'] = 25
fig3, axs = plt.subplots(3, 2, figsize=(27, 15))
fig3.subplots_adjust(top=0.98, bottom=0.06, right=0.97, left=0.05, hspace=0.3, wspace=0.1)
# axs.flat walks the panels row by row, so panel i is drawn for param_keys[i].
for i, (key, ax3) in enumerate(zip(param_keys, axs.flat)):
    for side in ('bottom', 'left', 'right', 'top'):
        ax3.spines[side].set_linewidth(2)
    ax3.grid(axis='y', linestyle='-.', linewidth=2)
    # Move axis i to the front and flatten the rest: row j then holds every
    # table entry at the j-th node of `key`.
    data = np.swapaxes(LUT, 0, i)
    data = data.reshape(data.shape[0], -1)
    n_nodes = data.shape[0]
    exceeds = data > ERR_THRESHOLD
    # Per-node fractions, followed by the whole-table fraction ('all' bar).
    fractions = np.r_[exceeds.mean(axis=-1), [exceeds.mean()]]
    ax3.bar(np.arange(n_nodes + 1), fractions, color=colors[i])
    ax3.set_ylim([0, 0.15])  # NOTE: bars taller than 0.15 are clipped by this limit
    ax3.set_yticks([0, 0.05, 0.1, 0.15])
    ax3.set_xlim([-0.5, n_nodes + 0.5])
    ax3.set_xticks(np.arange(n_nodes + 1))
    ax3.set_xticklabels(np.r_[LUT_parameters[key], ['all']])
    ax3.annotate('({})'.format(order[i]), xy=(0.05, 0.9), xycoords='axes fraction', fontsize=20,
                 horizontalalignment='center', verticalalignment='top')
    ax3.set_xlabel(key)
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs('temp', exist_ok=True)
fig3.savefig('temp/Figure3.eps')
fig3.savefig('temp/Figure3.png')
plt.close(fig3)
# plt.bar(np.arange(11), np.sum(LUT2>0.1, axis=-1)/LUT2.shape[1])
# plt.ylim([0, 0.1])
# plt.xticks(np.arange(11), LUT_parameters['sza'])
# plt.show()
# plt.figure(figsize=(10,3))
# plt.bar(np.arange(11), np.sum(LUT2>0.1, axis=-1)/LUT2.shape[1])
# plt.ylim([0, 0.1])
# plt.xticks(np.arange(11), LUT_parameters['sza'])
# plt.show()